package cn.iocoder.yudao.module.system.util;

import lombok.extern.slf4j.Slf4j;

import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Performance monitoring utility.
 *
 * <p>Keeps in-memory call statistics per operation name (call count, total / min / max /
 * average duration, success rate) and one shared aggregate for batch filtering. Intended for
 * monitoring operations such as enrollment-plan filtering.
 *
 * <p>All state lives in static fields, so statistics are shared by every caller in the JVM
 * that loads this class. Durations are treated as milliseconds throughout (the log and
 * {@code toString} output label them {@code ms}).
 */
@Slf4j
public class PerformanceMonitor {

    /**
     * Running statistics for one named operation.
     *
     * <p>Every counter is a separate atomic variable, so concurrent updates are not lost.
     * The getters read the counters one at a time, however, so a value derived from two
     * counters (average time, success rate) can be momentarily inconsistent while another
     * thread is in the middle of {@link #recordCall(long, boolean)}.
     */
    public static class PerformanceMetrics {
        private final AtomicInteger totalCalls = new AtomicInteger(0);
        private final AtomicLong totalTime = new AtomicLong(0);
        private final AtomicInteger successCalls = new AtomicInteger(0);
        private final AtomicInteger failureCalls = new AtomicInteger(0);
        // Long.MAX_VALUE is the "nothing recorded yet" sentinel; getMinTime() maps it to 0.
        private final AtomicLong minTime = new AtomicLong(Long.MAX_VALUE);
        private final AtomicLong maxTime = new AtomicLong(0);

        /**
         * Records one finished call.
         *
         * @param duration elapsed time of the call, in milliseconds
         * @param success  {@code true} to count the call as a success, {@code false} as a failure
         */
        public void recordCall(long duration, boolean success) {
            totalCalls.incrementAndGet();
            totalTime.addAndGet(duration);

            if (success) {
                successCalls.incrementAndGet();
            } else {
                failureCalls.incrementAndGet();
            }

            // Lower the minimum with a CAS retry loop. The loop stops as soon as the stored
            // value is already <= duration, so no write happens when nothing changes.
            long currentMin = minTime.get();
            while (duration < currentMin && !minTime.compareAndSet(currentMin, duration)) {
                currentMin = minTime.get();
            }

            // Raise the maximum the same way.
            long currentMax = maxTime.get();
            while (duration > currentMax && !maxTime.compareAndSet(currentMax, duration)) {
                currentMax = maxTime.get();
            }
        }

        /**
         * @return total recorded time divided by the number of calls, or {@code 0} when no
         *         call has been recorded
         */
        public double getAverageTime() {
            int calls = totalCalls.get();
            return calls > 0 ? (double) totalTime.get() / calls : 0;
        }

        /**
         * @return successful calls as a percentage (0-100) of all calls, or {@code 0} when no
         *         call has been recorded
         */
        public double getSuccessRate() {
            int total = totalCalls.get();
            return total > 0 ? (double) successCalls.get() / total * 100 : 0;
        }

        /** @return number of calls recorded so far */
        public int getTotalCalls() {
            return totalCalls.get();
        }

        /** @return sum of all recorded durations */
        public long getTotalTime() {
            return totalTime.get();
        }

        /** @return number of calls recorded as successful */
        public int getSuccessCalls() {
            return successCalls.get();
        }

        /** @return number of calls recorded as failed */
        public int getFailureCalls() {
            return failureCalls.get();
        }

        /** @return smallest recorded duration, or {@code 0} when no call has been recorded */
        public long getMinTime() {
            long min = minTime.get();
            return min == Long.MAX_VALUE ? 0 : min;
        }

        /** @return largest recorded duration ({@code 0} until a positive duration is recorded) */
        public long getMaxTime() {
            return maxTime.get();
        }

        @Override
        public String toString() {
            return String.format(
                "PerformanceMetrics{totalCalls=%d, avgTime=%.2fms, successRate=%.1f%%, minTime=%dms, maxTime=%dms}",
                getTotalCalls(), getAverageTime(), getSuccessRate(), getMinTime(), getMaxTime()
            );
        }
    }

    /** Statistics per operation name; entries are created lazily on first record. */
    private static final ConcurrentHashMap<String, PerformanceMetrics> metrics = new ConcurrentHashMap<>();

    /**
     * Records one call of the given operation, creating its statistics entry if needed.
     *
     * @param operationName key under which the call is accumulated
     * @param duration      elapsed time of the call, in milliseconds
     * @param success       whether the call succeeded
     */
    public static void recordOperation(String operationName, long duration, boolean success) {
        metrics.computeIfAbsent(operationName, k -> new PerformanceMetrics())
               .recordCall(duration, success);
    }

    /**
     * Returns the live statistics object for an operation.
     *
     * @param operationName operation to look up
     * @return the statistics, or {@code null} if nothing has been recorded under that name
     */
    public static PerformanceMetrics getMetrics(String operationName) {
        return metrics.get(operationName);
    }

    /**
     * Logs a header line followed by one INFO line per operation.
     */
    public static void printAllMetrics() {
        log.info("=== 性能监控报告 ===");
        metrics.forEach((operation, metric) -> log.info("{}: {}", operation, metric));
    }

    /**
     * Removes the statistics of a single operation.
     *
     * @param operationName operation whose statistics are discarded
     */
    public static void clearMetrics(String operationName) {
        metrics.remove(operationName);
    }

    /**
     * Removes the statistics of every operation. The batch-filter aggregate returned by
     * {@link #getBatchFilterStats()} is not touched.
     */
    public static void clearAllMetrics() {
        metrics.clear();
    }

    /**
     * Stopwatch for one operation: starts when constructed and reports the elapsed time to
     * {@link #recordOperation(String, long, boolean)} when one of the record methods is called.
     *
     * <p>Each record method call adds a new sample measured from construction time; the timer
     * does not prevent being recorded more than once.
     */
    public static class PerformanceTimer {
        private static final long NANOS_PER_MILLI = 1_000_000L;

        private final String operationName;
        // System.nanoTime() rather than currentTimeMillis(): the wall clock can be adjusted
        // (even backwards) while an operation runs, which would yield wrong or negative
        // durations. nanoTime is meant for measuring elapsed time.
        private final long startNanos;

        /**
         * @param operationName name the elapsed time will be recorded under
         */
        public PerformanceTimer(String operationName) {
            this.operationName = operationName;
            this.startNanos = System.nanoTime();
        }

        /** Records the time elapsed since construction as a successful call. */
        public void recordSuccess() {
            record(true);
        }

        /** Records the time elapsed since construction as a failed call. */
        public void recordFailure() {
            record(false);
        }

        /**
         * Records the time elapsed since construction.
         *
         * @param success whether the call is counted as a success
         */
        public void recordResult(boolean success) {
            record(success);
        }

        /** Converts the elapsed nanoseconds to whole milliseconds and reports them. */
        private void record(boolean success) {
            long duration = (System.nanoTime() - startNanos) / NANOS_PER_MILLI;
            recordOperation(operationName, duration, success);
        }
    }

    /**
     * Creates a timer that starts immediately.
     *
     * @param operationName name the elapsed time will be recorded under
     * @return a running timer
     */
    public static PerformanceTimer startTimer(String operationName) {
        return new PerformanceTimer(operationName);
    }

    /**
     * Aggregate statistics over batch-filter runs: number of batches, items seen, items
     * counted as filtered, and total time. Counters are independent atomics and are never
     * reset by this class.
     */
    public static class BatchFilterStats {
        private final AtomicInteger totalBatches = new AtomicInteger(0);
        private final AtomicInteger totalItems = new AtomicInteger(0);
        private final AtomicInteger filteredItems = new AtomicInteger(0);
        private final AtomicLong totalTime = new AtomicLong(0);

        /**
         * Adds one batch to the aggregate.
         *
         * @param batchSize     number of items in the batch
         * @param filteredCount value added to the filtered-items counter for this batch
         * @param duration      time spent on the batch, in milliseconds
         */
        public void recordBatch(int batchSize, int filteredCount, long duration) {
            totalBatches.incrementAndGet();
            totalItems.addAndGet(batchSize);
            filteredItems.addAndGet(filteredCount);
            totalTime.addAndGet(duration);
        }

        /**
         * @return filtered items as a percentage of all items, or {@code 0} when no items
         *         have been recorded
         */
        public double getFilterRate() {
            int total = totalItems.get();
            return total > 0 ? (double) filteredItems.get() / total * 100 : 0;
        }

        /**
         * @return total time divided by total items, or {@code 0} when no items have been
         *         recorded
         */
        public double getAverageTimePerItem() {
            int total = totalItems.get();
            return total > 0 ? (double) totalTime.get() / total : 0;
        }

        /**
         * Logs the aggregate at INFO level.
         *
         * <p>SLF4J only substitutes the literal {@code {}} placeholder and has no
         * precision specifiers, so the two decimal values are formatted with
         * {@link String#format(String, Object...)} before being passed to the logger.
         */
        public void printStats() {
            log.info("批量过滤统计: 总批次={}, 总项目={}, 过滤后={}, 过滤率={}%, 平均耗时={}ms/项",
                    totalBatches.get(), totalItems.get(), filteredItems.get(),
                    String.format("%.1f", getFilterRate()),
                    String.format("%.2f", getAverageTimePerItem()));
        }
    }

    /** Single shared batch-filter aggregate. */
    private static final BatchFilterStats batchFilterStats = new BatchFilterStats();

    /**
     * Adds one batch to the shared batch-filter aggregate.
     *
     * @param batchSize     number of items in the batch
     * @param filteredCount value added to the filtered-items counter for this batch
     * @param duration      time spent on the batch, in milliseconds
     */
    public static void recordBatchFilter(int batchSize, int filteredCount, long duration) {
        batchFilterStats.recordBatch(batchSize, filteredCount, duration);
    }

    /**
     * @return the shared, live batch-filter aggregate (never {@code null})
     */
    public static BatchFilterStats getBatchFilterStats() {
        return batchFilterStats;
    }

    /**
     * Logs the shared batch-filter aggregate at INFO level.
     */
    public static void printBatchFilterStats() {
        batchFilterStats.printStats();
    }
}
